import torch
import torch.nn as nn

from multiprocessing import Pool
from math import ceil

from ..layers import DLRTModule


class DLRT_Optimizer:
    """Optimizer driving dynamical low-rank training (DLRT) of a network.

    The parameter updates themselves are delegated to a wrapped torch optimizer
    (``self.integrator``); this class sequences the K/L/S phases around it by
    calling the phase methods exposed by the network and by its low-rank layers.
    """

    def __init__(self, NN: nn.Module, absolute=False,
                 KLS_optim=None, baseline=False, n_workers=1, **kwargs):

        """
        Initialize the dynamical low-rank optimizer.

        INPUTS:
        NN : network with custom layers, methods and attributes (look at Lenet5 for an
             example). It must expose ``lr_model`` (the module whose parameters are
             optimized), ``lr_modules``, ``set_step`` and the phase methods
             ``K/L/S_preprocess_step``, ``K/L/S_postprocess_step`` and ``update_Q``.
        absolute : stored flag, True if the singular-value tolerance has to be
             interpreted as absolute; it is not read inside this class.
        KLS_optim : optimizer class used as integrator (default: torch.optim.SGD).
        baseline : if True, ``step`` performs a plain optimizer step with no K/L/S phases.
        n_workers : number of worker processes for the phase methods; 1 runs serially.
        kwargs : integrator arguments, e.g. learning rate ``lr`` (integration step) and
             ``momentum``. ``wd`` is NOT forwarded to the integrator: it is the
             weight-decay coefficient applied by ``add_weight_decay``. ``lr`` is also
             read by the multistep branch of ``S_integration_step``.

        Every phase call (e.g. ``K_preprocess_step()``) calls the homonymous method on
        the network.
        """

        self.NN = NN
        self.absolute = absolute
        self.kw = dict(kwargs)
        self.KLS_optim = KLS_optim
        self.baseline = baseline
        self.n_workers = n_workers
        self.multi_thread = (self.n_workers != 1)
        # two-slot gradient history [older, newer] for the Adams-Bashforth branch
        self.S_grad_cum = []
        self.integrator = self._build_integrator(self.NN.lr_model.parameters())
        self.preprocess_step()

    def _build_integrator(self, params):
        """Create the wrapped optimizer over ``params``.

        Uses ``KLS_optim`` when given, otherwise ``torch.optim.SGD``. The ``wd`` entry
        of the stored kwargs is dropped because weight decay is applied manually by
        ``add_weight_decay``.
        """
        optim_cls = self.KLS_optim if self.KLS_optim is not None else torch.optim.SGD
        optim_args = {k: v for k, v in self.kw.items() if k != 'wd'}
        return optim_cls(params, **optim_args)

    # preprocess steps

    @torch.no_grad()
    def add_weight_decay(self, model):
        """Add L2 weight decay (``grad += wd * param``) to the gradients of ``model``.

        ``wd`` is read from the optimizer kwargs; when it is missing or zero this is a
        no-op. For active DLRT layers only the leading ``dynamic_rank`` block of the
        core (``C`` for layers without ``S_hat``, ``S_hat`` otherwise) is decayed, plus
        the bias. Modules with a ``weight`` attribute decay weight and bias; any other
        module is traversed recursively through its children.
        """
        wd = self.kw.get('wd', 0.0)
        if not wd:
            return
        self._apply_weight_decay(model, wd)

    @torch.no_grad()
    def _apply_weight_decay(self, model, wd):
        """Recursive worker of ``add_weight_decay`` with the coefficient resolved."""
        # check the cheap attribute first so plain modules never need DLRTModule
        is_dlrt = getattr(model, 'dlrt', False) and isinstance(model, DLRTModule)
        if is_dlrt and not hasattr(model, 'S_hat'):
            if model.C.grad is not None:
                r1, r2, r3, r4 = model.dynamic_rank
                # decay the parameter block itself (adding grad to grad would only
                # rescale the gradient, which clip_grad then normalizes away)
                model.C.grad[:r1, :r2, :r3, :r4].add_(model.C[:r1, :r2, :r3, :r4], alpha=wd)
                self._decay_bias(model, wd)
        elif is_dlrt:
            if model.S_hat.grad is not None:
                r = model.dynamic_rank
                model.S_hat.grad[:r, :r].add_(model.S_hat[:r, :r], alpha=wd)
                self._decay_bias(model, wd)
        elif hasattr(model, 'weight'):
            if model.weight is not None and model.weight.grad is not None:
                model.weight.grad.add_(model.weight, alpha=wd)
            self._decay_bias(model, wd)
        else:
            for child in model.children():
                self._apply_weight_decay(child, wd)

    @staticmethod
    def _decay_bias(model, wd):
        """Add ``wd * bias`` to the bias gradient when both exist (caller holds no_grad)."""
        bias = getattr(model, 'bias', None)
        if bias is not None and bias.grad is not None:
            bias.grad.add_(bias, alpha=wd)

    @staticmethod
    def _chunksize(n_items, n_workers):
        """Items per pool task so ``n_items`` are spread over ``n_workers`` (at least 1)."""
        return max(1, ceil(n_items / n_workers))

    def _run_layerwise(self, nn_method, layer_method):
        """Run one phase serially via ``self.NN.<nn_method>()`` or per layer in a pool.

        NOTE(review): no ``_single_layer_*`` method is defined in this class, so the
        multi-process path raises AttributeError unless a subclass provides them; also
        confirm that in-place layer updates made in worker processes reach the parent.
        """
        if not self.multi_thread:
            getattr(self.NN, nn_method)()
            return
        layers = self.NN.lr_model
        with Pool(self.n_workers) as pool:
            pool.map(getattr(self, layer_method), layers,
                     chunksize=self._chunksize(len(layers), self.n_workers))
            pool.close()
            pool.join()

    @torch.no_grad()
    def K_preprocess_step(self):
        """Run the network's ``K_preprocess_step`` phase."""
        self._run_layerwise('K_preprocess_step', '_single_layer_K_preprocess_step')

    @torch.no_grad()
    def L_preprocess_step(self):
        """Run the network's ``L_preprocess_step`` phase."""
        self._run_layerwise('L_preprocess_step', '_single_layer_L_preprocess_step')

    @torch.no_grad()
    def S_preprocess_step(self):
        """Run the network's ``S_preprocess_step`` phase."""
        self._run_layerwise('S_preprocess_step', '_single_layer_S_preprocess_step')

    @torch.no_grad()
    def preprocess_step(self):
        """Run the K and L preprocessing phases, in that order."""
        self.K_preprocess_step()
        self.L_preprocess_step()

    # postprocess steps

    @torch.no_grad()
    def K_postprocess_step(self):
        """Run the network's ``K_postprocess_step`` phase."""
        self._run_layerwise('K_postprocess_step', '_single_layer_K_postprocess_step')

    @torch.no_grad()
    def L_postprocess_step(self):
        """Run the network's ``L_postprocess_step`` phase."""
        self._run_layerwise('L_postprocess_step', '_single_layer_L_postprocess_step')

    @torch.no_grad()
    def S_postprocess_step(self):
        """Run the network's ``S_postprocess_step`` phase."""
        self._run_layerwise('S_postprocess_step', '_single_layer_S_postprocess_step')

    def update_Q(self):
        """Delegate to the network's ``update_Q``."""
        self.NN.update_Q()

    @torch.no_grad()
    def postprocess_step(self):
        """Run the K and L postprocessing phases, in that order."""
        self.K_postprocess_step()
        self.L_postprocess_step()

    # integration

    @torch.no_grad()
    def zero_bias_grad(self):
        """Despite the name, drop (set to None) the gradient of EVERY trainable
        parameter of ``lr_model``, not only biases."""
        for param in self.NN.lr_model.parameters():
            if param.requires_grad:
                param.grad = None

    @torch.no_grad()
    def K_integration_step(self):
        """Clear all trainable gradients, then step the integrator.

        NOTE(review): torch optimizers skip parameters whose grad is None, so with
        them this step updates nothing — confirm intent.
        """
        self.zero_bias_grad()
        self.integrator.step()

    @torch.no_grad()
    def L_integration_step(self):
        """Step the integrator, then drop all gradients."""
        self.integrator.step()
        self.integrator.zero_grad(set_to_none=True)

    @torch.no_grad()
    def K_and_L_integration_step(self):
        """Joint K/L step: step the integrator, then drop all gradients."""
        self.integrator.step()
        self.integrator.zero_grad(set_to_none=True)

    @staticmethod
    def _is_core_param(name):
        """True for parameters treated as cores by the multistep scheme (name contains
        ``C`` or ``S_hat``)."""
        return 'C' in name or 'S_hat' in name

    def _core_grad_snapshot(self):
        """Detached copies of the current gradients of all core parameters, by name."""
        return {n: p.grad.detach().clone()
                for n, p in self.NN.lr_model.named_parameters()
                if self._is_core_param(n) and p.grad is not None}

    @torch.no_grad()
    def S_integration_step(self, lmm=False):

        '''
        Core (S) integration step. Allows to use the Adams-Bashforth LMM method.

        With ``lmm=False`` this is a plain integrator step. With ``lmm=True`` the core
        parameters follow the two-step Adams-Bashforth update
            p <- p - lr / 2 * (3 * g_now - g_prev),
        bootstrapped by one ordinary integrator step while the gradient history is
        filled. ``lr`` is read from the optimizer kwargs.
        '''

        if not lmm:
            self.integrator.step()
            return

        if len(self.S_grad_cum) < 2:
            # bootstrap: one ordinary step, then seed both history slots with the
            # gradients that step used
            self.integrator.step()
            self.S_grad_cum.append(self._core_grad_snapshot())
            self.S_grad_cum.append(self._core_grad_snapshot())
            self.integrator.zero_grad()
            return

        # TODO: check in the future whether this allows a larger time step
        lr = self.kw['lr']
        older, newer = self.S_grad_cum[0], self.S_grad_cum[1]
        for name, p in self.NN.lr_model.named_parameters():
            if not self._is_core_param(name) or p.grad is None or name not in newer:
                continue
            g_now = p.grad.detach().clone()
            # must be in-place: an out-of-place ``p.add`` discards the update
            p.add_(3 * g_now - newer[name], alpha=-0.5 * lr)
            # shift history; the older slot is kept up to date for inspection only
            older[name] = newer[name]
            newer[name] = g_now

    @torch.no_grad()
    def normalize_core(self):  # just a test
        """Experimental: rescale the core ``C`` of each active DLRT layer to unit norm."""
        for layer in self.NN.lr_model:
            if hasattr(layer, 'dlrt') and layer.dlrt:
                layer.C.mul_(1. / torch.norm(layer.C))

    @torch.no_grad()
    def step(self, closure=None):

        """
        Optimizer step for the DLRT.

        Non-baseline sequence: normalize gradients, joint K/L integration step, per-layer
        K/L postprocess and S preprocess, optional re-evaluation of the loss through
        ``closure`` (with backward), weight decay, gradient normalization, S integration
        step, gradient reset, then per-layer S postprocess, K/L preprocess and
        ``update_Q``. Baseline: optional closure, gradient normalization, integrator step.

        INPUTS:
        closure : function to compute the loss and backpropagate a second time (Pytorch standard)
        """
        if not self.baseline:
            self.clip_grad()
            self.K_and_L_integration_step()
            for layer in self.NN.lr_modules:
                layer.K_postprocess_step()
                layer.L_postprocess_step()
                layer.S_preprocess_step()
            if closure is not None:
                # the method runs under no_grad; the closure needs autograd for backward
                with torch.set_grad_enabled(True):
                    loss = closure()
                    loss.backward()
            self.add_weight_decay(self.NN.lr_model)
            self.clip_grad()
            self.S_integration_step()
            self.integrator.zero_grad(set_to_none=True)
            for layer in self.NN.lr_modules:
                layer.S_postprocess_step()
                layer.K_preprocess_step()
                layer.L_preprocess_step()
                layer.update_Q()
        else:
            if closure is not None:
                with torch.set_grad_enabled(True):
                    loss = closure()
                    loss.backward()
            self.clip_grad()
            self.integrator.step()

    @torch.no_grad()
    def zero_grad(self):
        """Drop all gradients tracked by the integrator."""
        self.integrator.zero_grad(set_to_none=True)

    @torch.no_grad()
    def clip_grad(self):
        '''
        Gradient "clipping": rescale every existing gradient of ``lr_model`` to unit norm.

        This is a normalization (gradients with norm below 1 are scaled up as well).
        All-zero gradients are left untouched, since dividing by a zero norm would fill
        them with NaN.
        '''
        for p in self.NN.lr_model.parameters():
            if p.grad is None:
                continue
            grad_norm = torch.norm(p.grad)
            if grad_norm == 0:
                continue
            p.grad.mul_(1. / grad_norm)

    ######### fine-tuning for tensor version

    @torch.no_grad()
    def activate_S_fine_tuning(self):

        '''
        Activate the fine-tuning phase for DLRT (train only the core).

        For active DLRT layers: freeze the K factors, truncate and freeze the U bases
        to the current ranks, truncate the core ``C`` to a new trainable Parameter and
        set ``tau`` to 0. Full-rank layers (``lr`` False) keep training weight and bias.
        A fresh integrator is built over the remaining trainable tensors, using the same
        optimizer class and arguments as at construction time.
        '''
        if not self.baseline:
            params = []

            for layer in self.NN.lr_model:

                if hasattr(layer, 'dlrt') and layer.dlrt:

                    for i, K in enumerate(layer.Ks):
                        K.requires_grad = False
                        layer.Us[i] = torch.nn.Parameter(
                            layer.Us[i][:, :layer.dynamic_rank[i]], requires_grad=False)
                    r = layer.dynamic_rank
                    layer.C = torch.nn.Parameter(layer.C[:r[0], :r[1], :r[2], :r[3]])
                    layer.tau = 0.0
                    params.append(layer.C)
                    if layer.bias is not None:
                        params.append(layer.bias)
                elif hasattr(layer, 'lr') and not layer.lr:
                    params.append(layer.weight)
                    if layer.bias is not None:
                        params.append(layer.bias)
            self.NN.set_step('test')
            params = torch.nn.ParameterList(params)
            self.integrator = self._build_integrator(params)

    @torch.no_grad()
    def S_finetune_step(self):

        '''
        Fine-tuning optimization step on the core (plain integrator step).
        '''
        self.integrator.step()
       
            

